{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#skip\n",
    "! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp layers\n",
    "# default_cls_lvl 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from fastai.imports import *\n",
    "from fastai.torch_imports import *\n",
    "from fastai.torch_core import *\n",
    "from torch.nn.utils import weight_norm, spectral_norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Layers\n",
    "> Custom fastai layers and basic functions to grab them."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Basic manipulations and resize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def module(*flds, **defaults):\n",
    "    \"Decorator to create an `nn.Module` using `f` as `forward` method\"\n",
    "    pa = [inspect.Parameter(o, inspect.Parameter.POSITIONAL_OR_KEYWORD) for o in flds]\n",
    "    pb = [inspect.Parameter(k, inspect.Parameter.POSITIONAL_OR_KEYWORD, default=v)\n",
    "          for k,v in defaults.items()]\n",
    "    params = pa+pb\n",
    "    all_flds = [*flds,*defaults.keys()]\n",
    "\n",
    "    def _f(f):\n",
    "        class c(nn.Module):\n",
    "            def __init__(self, *args, **kwargs):\n",
    "                super().__init__()\n",
    "                for i,o in enumerate(args): kwargs[all_flds[i]] = o\n",
    "                kwargs = merge(defaults,kwargs)\n",
    "                for k,v in kwargs.items(): setattr(self,k,v)\n",
    "            __repr__ = basic_repr(all_flds)\n",
    "            forward = f\n",
    "        c.__signature__ = inspect.Signature(params)\n",
    "        c.__name__ = c.__qualname__ = f.__name__\n",
    "        c.__doc__  = f.__doc__\n",
    "        return c\n",
    "    return _f"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@module()\n",
    "def Identity(self, x):\n",
    "    \"Do nothing at all\"\n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(Identity()(1), 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "@module('func')\n",
    "def Lambda(self, x):\n",
    "    \"An easy way to create a pytorch layer for a simple `func`\"\n",
    "    return self.func(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Lambda(func=<function _add2 at 0x7f2b98416d30>)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def _add2(x): return x+2\n",
    "tst = Lambda(_add2)\n",
    "x = torch.randn(10,20)\n",
    "test_eq(tst(x), x+2)\n",
    "tst2 = pickle.loads(pickle.dumps(tst))\n",
    "test_eq(tst2(x), x+2)\n",
    "tst"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class PartialLambda(Lambda):\n",
    "    \"Layer that applies `partial(func, **kwargs)`\"\n",
    "    def __init__(self, func, **kwargs):\n",
    "        super().__init__(partial(func, **kwargs))\n",
    "        self.repr = f'{func.__name__}, {kwargs}'\n",
    "\n",
    "    def forward(self, x): return self.func(x)\n",
    "    def __repr__(self): return f'{self.__class__.__name__}({self.repr})'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_func(a,b=2): return a+b\n",
    "tst = PartialLambda(test_func, b=5)\n",
    "test_eq(tst(x), x+5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "@module(full=False)\n",
    "def Flatten(self, x):\n",
    "    \"Flatten `x` to a single dimension, e.g. at end of a model. `full` for rank-1 tensor\"\n",
    "    return TensorBase(x.view(-1) if self.full else x.view(x.size(0), -1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst = Flatten()\n",
    "x = torch.randn(10,5,4)\n",
    "test_eq(tst(x).shape, [10,20])\n",
    "tst = Flatten(full=True)\n",
    "test_eq(tst(x).shape, [200])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@module(tensor_cls=TensorBase)\n",
    "def ToTensorBase(self, x):\n",
    "    \"Remove `tensor_cls` to x\"\n",
    "    return self.tensor_cls(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ttb = ToTensorBase()\n",
    "timg = TensorImage(torch.rand(1,3,32,32))\n",
    "test_eq(type(ttb(timg)), TensorBase)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class View(Module):\n",
    "    \"Reshape `x` to `size`\"\n",
    "    def __init__(self, *size): self.size = size\n",
    "    def forward(self, x): return x.view(self.size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst = View(10,5,4)\n",
    "test_eq(tst(x).shape, [10,5,4])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class ResizeBatch(Module):\n",
    "    \"Reshape `x` to `size`, keeping batch dim the same size\"\n",
    "    def __init__(self, *size): self.size = size\n",
    "    def forward(self, x): return x.view((x.size(0),) + self.size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst = ResizeBatch(5,4)\n",
    "test_eq(tst(x).shape, [10,5,4])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "@module()\n",
    "def Debugger(self,x):\n",
    "    \"A module to debug inside a model.\"\n",
    "    set_trace()\n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def sigmoid_range(x, low, high):\n",
    "    \"Sigmoid function with range `(low, high)`\"\n",
    "    return torch.sigmoid(x) * (high - low) + low"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test = tensor([-10.,0.,10.])\n",
    "assert torch.allclose(sigmoid_range(test, -1,  2), tensor([-1.,0.5, 2.]), atol=1e-4, rtol=1e-4)\n",
    "assert torch.allclose(sigmoid_range(test, -5, -1), tensor([-5.,-3.,-1.]), atol=1e-4, rtol=1e-4)\n",
    "assert torch.allclose(sigmoid_range(test,  2,  4), tensor([2.,  3., 4.]), atol=1e-4, rtol=1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "@module('low','high')\n",
    "def SigmoidRange(self, x):\n",
    "    \"Sigmoid module with range `(low, high)`\"\n",
    "    return sigmoid_range(x, self.low, self.high)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst = SigmoidRange(-1, 2)\n",
    "assert torch.allclose(tst(test), tensor([-1.,0.5, 2.]), atol=1e-4, rtol=1e-4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pooling layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class AdaptiveConcatPool1d(Module):\n",
    "    \"Layer that concats `AdaptiveAvgPool1d` and `AdaptiveMaxPool1d`\"\n",
    "    def __init__(self, size=None):\n",
    "        self.size = size or 1\n",
    "        self.ap = nn.AdaptiveAvgPool1d(self.size)\n",
    "        self.mp = nn.AdaptiveMaxPool1d(self.size)\n",
    "    def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class AdaptiveConcatPool2d(Module):\n",
    "    \"Layer that concats `AdaptiveAvgPool2d` and `AdaptiveMaxPool2d`\"\n",
    "    def __init__(self, size=None):\n",
    "        self.size = size or 1\n",
    "        self.ap = nn.AdaptiveAvgPool2d(self.size)\n",
    "        self.mp = nn.AdaptiveMaxPool2d(self.size)\n",
    "    def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If the input is `bs x nf x h x h`, the output will be `bs x 2*nf x 1 x 1` if no size is passed or `bs x 2*nf x size x size`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst = AdaptiveConcatPool2d()\n",
    "x = torch.randn(10,5,4,4)\n",
    "test_eq(tst(x).shape, [10,10,1,1])\n",
    "max1 = torch.max(x,    dim=2, keepdim=True)[0]\n",
    "maxp = torch.max(max1, dim=3, keepdim=True)[0]\n",
    "test_eq(tst(x)[:,:5], maxp)\n",
    "test_eq(tst(x)[:,5:], x.mean(dim=[2,3], keepdim=True))\n",
    "tst = AdaptiveConcatPool2d(2)\n",
    "test_eq(tst(x).shape, [10,10,2,2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class PoolType: Avg,Max,Cat = 'Avg','Max','Cat'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def adaptive_pool(pool_type):\n",
    "    return nn.AdaptiveAvgPool2d if pool_type=='Avg' else nn.AdaptiveMaxPool2d if pool_type=='Max' else AdaptiveConcatPool2d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class PoolFlatten(nn.Sequential):\n",
    "    \"Combine `nn.AdaptiveAvgPool2d` and `Flatten`.\"\n",
    "    def __init__(self, pool_type=PoolType.Avg): super().__init__(adaptive_pool(pool_type)(1), Flatten())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst = PoolFlatten()\n",
    "test_eq(tst(x).shape, [10,5])\n",
    "test_eq(tst(x), x.mean(dim=[2,3]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## BatchNorm layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "NormType = Enum('NormType', 'Batch BatchZero Weight Spectral Instance InstanceZero')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def _get_norm(prefix, nf, ndim=2, zero=False, **kwargs):\n",
    "    \"Norm layer with `nf` features and `ndim` initialized depending on `norm_type`.\"\n",
    "    assert 1 <= ndim <= 3\n",
    "    bn = getattr(nn, f\"{prefix}{ndim}d\")(nf, **kwargs)\n",
    "    if bn.affine:\n",
    "        bn.bias.data.fill_(1e-3)\n",
    "        bn.weight.data.fill_(0. if zero else 1.)\n",
    "    return bn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@delegates(nn.BatchNorm2d)\n",
    "def BatchNorm(nf, ndim=2, norm_type=NormType.Batch, **kwargs):\n",
    "    \"BatchNorm layer with `nf` features and `ndim` initialized depending on `norm_type`.\"\n",
    "    return _get_norm('BatchNorm', nf, ndim, zero=norm_type==NormType.BatchZero, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@delegates(nn.InstanceNorm2d)\n",
    "def InstanceNorm(nf, ndim=2, norm_type=NormType.Instance, affine=True, **kwargs):\n",
    "    \"InstanceNorm layer with `nf` features and `ndim` initialized depending on `norm_type`.\"\n",
    "    return _get_norm('InstanceNorm', nf, ndim, zero=norm_type==NormType.InstanceZero, affine=affine, **kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`kwargs` are passed to `nn.BatchNorm` and can be `eps`, `momentum`, `affine` and `track_running_stats`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst = BatchNorm(15)\n",
    "assert isinstance(tst, nn.BatchNorm2d)\n",
    "test_eq(tst.weight, torch.ones(15))\n",
    "tst = BatchNorm(15, norm_type=NormType.BatchZero)\n",
    "test_eq(tst.weight, torch.zeros(15))\n",
    "tst = BatchNorm(15, ndim=1)\n",
    "assert isinstance(tst, nn.BatchNorm1d)\n",
    "tst = BatchNorm(15, ndim=3)\n",
    "assert isinstance(tst, nn.BatchNorm3d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst = InstanceNorm(15)\n",
    "assert isinstance(tst, nn.InstanceNorm2d)\n",
    "test_eq(tst.weight, torch.ones(15))\n",
    "tst = InstanceNorm(15, norm_type=NormType.InstanceZero)\n",
    "test_eq(tst.weight, torch.zeros(15))\n",
    "tst = InstanceNorm(15, ndim=1)\n",
    "assert isinstance(tst, nn.InstanceNorm1d)\n",
    "tst = InstanceNorm(15, ndim=3)\n",
    "assert isinstance(tst, nn.InstanceNorm3d)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If `affine` is false the weight should be `None`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(BatchNorm(15, affine=False).weight, None)\n",
    "test_eq(InstanceNorm(15, affine=False).weight, None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class BatchNorm1dFlat(nn.BatchNorm1d):\n",
    "    \"`nn.BatchNorm1d`, but first flattens leading dimensions\"\n",
    "    def forward(self, x):\n",
    "        if x.dim()==2: return super().forward(x)\n",
    "        *f,l = x.shape\n",
    "        x = x.contiguous().view(-1,l)\n",
    "        return super().forward(x).view(*f,l)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst = BatchNorm1dFlat(15)\n",
    "x = torch.randn(32, 64, 15)\n",
    "y = tst(x)\n",
    "mean = x.mean(dim=[0,1])\n",
    "test_close(tst.running_mean, 0*0.9 + mean*0.1)\n",
    "var = (x-mean).pow(2).mean(dim=[0,1])\n",
    "test_close(tst.running_var, 1*0.9 + var*0.1, eps=1e-4)\n",
    "test_close(y, (x-mean)/torch.sqrt(var+1e-5) * tst.weight + tst.bias, eps=1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class LinBnDrop(nn.Sequential):\n",
    "    \"Module grouping `BatchNorm1d`, `Dropout` and `Linear` layers\"\n",
    "    def __init__(self, n_in, n_out, bn=True, p=0., act=None, lin_first=False):\n",
    "        layers = [BatchNorm(n_out if lin_first else n_in, ndim=1)] if bn else []\n",
    "        if p != 0: layers.append(nn.Dropout(p))\n",
    "        lin = [nn.Linear(n_in, n_out, bias=not bn)]\n",
    "        if act is not None: lin.append(act)\n",
    "        layers = lin+layers if lin_first else layers+lin\n",
    "        super().__init__(*layers)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `BatchNorm` layer is skipped if `bn=False`, as is the dropout if `p=0.`. Optionally, you can add an activation for after the linear layer with `act`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst = LinBnDrop(10, 20)\n",
    "mods = list(tst.children())\n",
    "test_eq(len(mods), 2)\n",
    "assert isinstance(mods[0], nn.BatchNorm1d)\n",
    "assert isinstance(mods[1], nn.Linear)\n",
    "\n",
    "tst = LinBnDrop(10, 20, p=0.1)\n",
    "mods = list(tst.children())\n",
    "test_eq(len(mods), 3)\n",
    "assert isinstance(mods[0], nn.BatchNorm1d)\n",
    "assert isinstance(mods[1], nn.Dropout)\n",
    "assert isinstance(mods[2], nn.Linear)\n",
    "\n",
    "tst = LinBnDrop(10, 20, act=nn.ReLU(), lin_first=True)\n",
    "mods = list(tst.children())\n",
    "test_eq(len(mods), 3)\n",
    "assert isinstance(mods[0], nn.Linear)\n",
    "assert isinstance(mods[1], nn.ReLU)\n",
    "assert isinstance(mods[2], nn.BatchNorm1d)\n",
    "\n",
    "tst = LinBnDrop(10, 20, bn=False)\n",
    "mods = list(tst.children())\n",
    "test_eq(len(mods), 1)\n",
    "assert isinstance(mods[0], nn.Linear)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def sigmoid(input, eps=1e-7):\n",
    "    \"Same as `torch.sigmoid`, plus clamping to `(eps,1-eps)\"\n",
    "    return input.sigmoid().clamp(eps,1-eps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def sigmoid_(input, eps=1e-7):\n",
    "    \"Same as `torch.sigmoid_`, plus clamping to `(eps,1-eps)\"\n",
    "    return input.sigmoid_().clamp_(eps,1-eps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from torch.nn.init import kaiming_uniform_,uniform_,xavier_uniform_,normal_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def vleaky_relu(input, inplace=True):\n",
    "    \"`F.leaky_relu` with 0.3 slope\"\n",
    "    return F.leaky_relu(input, negative_slope=0.3, inplace=inplace)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "for o in F.relu,nn.ReLU,F.relu6,nn.ReLU6,F.leaky_relu,nn.LeakyReLU:\n",
    "    o.__default_init__ = kaiming_uniform_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "for o in F.sigmoid,nn.Sigmoid,F.tanh,nn.Tanh,sigmoid,sigmoid_:\n",
    "    o.__default_init__ = xavier_uniform_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def init_default(m, func=nn.init.kaiming_normal_):\n",
    "    \"Initialize `m` weights with `func` and set `bias` to 0.\"\n",
    "    if func and hasattr(m, 'weight'): func(m.weight)\n",
    "    with torch.no_grad():\n",
    "        if getattr(m, 'bias', None) is not None: m.bias.fill_(0.)\n",
    "    return m"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def init_linear(m, act_func=None, init='auto', bias_std=0.01):\n",
    "    if getattr(m,'bias',None) is not None and bias_std is not None:\n",
    "        if bias_std != 0: normal_(m.bias, 0, bias_std)\n",
    "        else: m.bias.data.zero_()\n",
    "    if init=='auto':\n",
    "        if act_func in (F.relu_,F.leaky_relu_): init = kaiming_uniform_\n",
    "        else: init = getattr(act_func.__class__, '__default_init__', None)\n",
    "        if init is None: init = getattr(act_func, '__default_init__', None)\n",
    "    if init is not None: init(m.weight)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Convolutions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def _conv_func(ndim=2, transpose=False):\n",
    "    \"Return the proper conv `ndim` function, potentially `transposed`.\"\n",
    "    assert 1 <= ndim <=3\n",
    "    return getattr(nn, f'Conv{\"Transpose\" if transpose else \"\"}{ndim}d')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "test_eq(_conv_func(ndim=1),torch.nn.modules.conv.Conv1d)\n",
    "test_eq(_conv_func(ndim=2),torch.nn.modules.conv.Conv2d)\n",
    "test_eq(_conv_func(ndim=3),torch.nn.modules.conv.Conv3d)\n",
    "test_eq(_conv_func(ndim=1, transpose=True),torch.nn.modules.conv.ConvTranspose1d)\n",
    "test_eq(_conv_func(ndim=2, transpose=True),torch.nn.modules.conv.ConvTranspose2d)\n",
    "test_eq(_conv_func(ndim=3, transpose=True),torch.nn.modules.conv.ConvTranspose3d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "defaults.activation=nn.ReLU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class ConvLayer(nn.Sequential):\n",
    "    \"Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and `norm_type` layers.\"\n",
    "    @delegates(nn.Conv2d)\n",
    "    def __init__(self, ni, nf, ks=3, stride=1, padding=None, bias=None, ndim=2, norm_type=NormType.Batch, bn_1st=True,\n",
    "                 act_cls=defaults.activation, transpose=False, init='auto', xtra=None, bias_std=0.01, **kwargs):\n",
    "        if padding is None: padding = ((ks-1)//2 if not transpose else 0)\n",
    "        bn = norm_type in (NormType.Batch, NormType.BatchZero)\n",
    "        inn = norm_type in (NormType.Instance, NormType.InstanceZero)\n",
    "        if bias is None: bias = not (bn or inn)\n",
    "        conv_func = _conv_func(ndim, transpose=transpose)\n",
    "        conv = conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding, **kwargs)\n",
    "        act = None if act_cls is None else act_cls()\n",
    "        init_linear(conv, act, init=init, bias_std=bias_std)\n",
    "        if   norm_type==NormType.Weight:   conv = weight_norm(conv)\n",
    "        elif norm_type==NormType.Spectral: conv = spectral_norm(conv)\n",
    "        layers = [conv]\n",
    "        act_bn = []\n",
    "        if act is not None: act_bn.append(act)\n",
    "        if bn: act_bn.append(BatchNorm(nf, norm_type=norm_type, ndim=ndim))\n",
    "        if inn: act_bn.append(InstanceNorm(nf, norm_type=norm_type, ndim=ndim))\n",
    "        if bn_1st: act_bn.reverse()\n",
    "        layers += act_bn\n",
    "        if xtra: layers.append(xtra)\n",
    "        super().__init__(*layers)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The convolution uses `ks` (kernel size) `stride`, `padding` and `bias`. `padding` will default to the appropriate value (`(ks-1)//2` if it's not a transposed conv) and `bias` will default to `True` the `norm_type` is `Spectral` or `Weight`, `False` if it's `Batch` or `BatchZero`. Note that if you don't want any normalization, you should pass `norm_type=None`.\n",
    "\n",
    "This defines a conv layer with `ndim` (1,2 or 3) that will be a ConvTranspose if `transpose=True`. `act_cls` is the class of the activation function to use (instantiated inside). Pass `act=None` if you don't want an activation function. If you quickly want to change your default activation, you can change the value of `defaults.activation`.\n",
    "\n",
    "`init` is used to initialize the weights (the bias are initialized to 0) and `xtra` is an optional layer to add at the end."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst = ConvLayer(16, 32)\n",
    "mods = list(tst.children())\n",
    "test_eq(len(mods), 3)\n",
    "test_eq(mods[1].weight, torch.ones(32))\n",
    "test_eq(mods[0].padding, (1,1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = torch.randn(64, 16, 8, 8)#.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Padding is selected to make the shape the same if stride=1\n",
    "test_eq(tst(x).shape, [64,32,8,8])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Padding is selected to make the shape half if stride=2\n",
    "tst = ConvLayer(16, 32, stride=2)\n",
    "test_eq(tst(x).shape, [64,32,4,4])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#But you can always pass your own padding if you want\n",
    "tst = ConvLayer(16, 32, padding=0)\n",
    "test_eq(tst(x).shape, [64,32,6,6])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#No bias by default for Batch NormType\n",
    "assert mods[0].bias is None\n",
    "#But can be overridden with `bias=True`\n",
    "tst = ConvLayer(16, 32, bias=True)\n",
    "assert first(tst.children()).bias is not None\n",
    "#For no norm, or spectral/weight, bias is True by default\n",
    "for t in [None, NormType.Spectral, NormType.Weight]:\n",
    "    tst = ConvLayer(16, 32, norm_type=t)\n",
    "    assert first(tst.children()).bias is not None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Various n_dim/tranpose\n",
    "tst = ConvLayer(16, 32, ndim=3)\n",
    "assert isinstance(list(tst.children())[0], nn.Conv3d)\n",
    "tst = ConvLayer(16, 32, ndim=1, transpose=True)\n",
    "assert isinstance(list(tst.children())[0], nn.ConvTranspose1d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#No activation/leaky\n",
    "tst = ConvLayer(16, 32, ndim=3, act_cls=None)\n",
    "mods = list(tst.children())\n",
    "test_eq(len(mods), 2)\n",
    "tst = ConvLayer(16, 32, ndim=3, act_cls=partial(nn.LeakyReLU, negative_slope=0.1))\n",
    "mods = list(tst.children())\n",
    "test_eq(len(mods), 3)\n",
    "assert isinstance(mods[2], nn.LeakyReLU)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# #export\n",
    "# def linear(in_features, out_features, bias=True, act_cls=None, init='auto'):\n",
    "#     \"Linear layer followed by optional activation, with optional auto-init\"\n",
    "#     res = nn.Linear(in_features, out_features, bias=bias)\n",
    "#     if act_cls: act_cls = act_cls()\n",
    "#     init_linear(res, act_cls, init=init)\n",
    "#     if act_cls: res = nn.Sequential(res, act_cls)\n",
    "#     return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# #export\n",
    "# @delegates(ConvLayer)\n",
    "# def conv1d(ni, nf, ks, stride=1, ndim=1, norm_type=None, **kwargs):\n",
    "#     \"Convolutional layer followed by optional activation, with optional auto-init\"\n",
    "#     return ConvLayer(ni, nf, ks, stride=stride, ndim=ndim, norm_type=norm_type, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# #export\n",
    "# @delegates(ConvLayer)\n",
    "# def conv2d(ni, nf, ks, stride=1, ndim=2, norm_type=None, **kwargs):\n",
    "#     \"Convolutional layer followed by optional activation, with optional auto-init\"\n",
    "#     return ConvLayer(ni, nf, ks, stride=stride, ndim=ndim, norm_type=norm_type, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# #export\n",
    "# @delegates(ConvLayer)\n",
    "# def conv3d(ni, nf, ks, stride=1, ndim=3, norm_type=None, **kwargs):\n",
    "#     \"Convolutional layer followed by optional activation, with optional auto-init\"\n",
    "#     return ConvLayer(ni, nf, ks, stride=stride, ndim=ndim, norm_type=norm_type, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def AdaptiveAvgPool(sz=1, ndim=2):\n",
    "    \"nn.AdaptiveAvgPool layer for `ndim`\"\n",
    "    assert 1 <= ndim <= 3\n",
    "    return getattr(nn, f\"AdaptiveAvgPool{ndim}d\")(sz)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def MaxPool(ks=2, stride=None, padding=0, ndim=2, ceil_mode=False):\n",
    "    \"nn.MaxPool layer for `ndim`\"\n",
    "    assert 1 <= ndim <= 3\n",
    "    return getattr(nn, f\"MaxPool{ndim}d\")(ks, stride=stride, padding=padding)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def AvgPool(ks=2, stride=None, padding=0, ndim=2, ceil_mode=False):\n",
    "    \"nn.AvgPool layer for `ndim`\"\n",
    "    assert 1 <= ndim <= 3\n",
    "    return getattr(nn, f\"AvgPool{ndim}d\")(ks, stride=stride, padding=padding, ceil_mode=ceil_mode)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def trunc_normal_(x, mean=0., std=1.):\n",
    "    \"Truncated normal initialization (approximation)\"\n",
    "    # From https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/12\n",
    "    return x.normal_().fmod_(2).mul_(std).add_(mean)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class Embedding(nn.Embedding):\n",
    "    \"Embedding layer with truncated normal initialization\"\n",
    "    def __init__(self, ni, nf, std=0.01):\n",
    "        super().__init__(ni, nf)\n",
    "        trunc_normal_(self.weight.data, std=std)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Truncated normal initialization bounds the distribution to avoid large value. For a given standard deviation `std`, the bounds are roughly `-2*std`, `2*std`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "std = 0.02\n",
    "tst = Embedding(10, 30, std)\n",
    "assert tst.weight.min() > -2*std\n",
    "assert tst.weight.max() < 2*std\n",
    "test_close(tst.weight.mean(), 0, 1e-2)\n",
    "test_close(tst.weight.std(), std, 0.1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Self attention"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class SelfAttention(Module):\n",
    "    \"Self attention layer for `n_channels`.\"\n",
    "    def __init__(self, n_channels):\n",
    "        self.query,self.key,self.value = [self._conv(n_channels, c) for c in (n_channels//8,n_channels//8,n_channels)]\n",
    "        self.gamma = nn.Parameter(tensor([0.]))\n",
    "\n",
    "    def _conv(self,n_in,n_out):\n",
    "        return ConvLayer(n_in, n_out, ks=1, ndim=1, norm_type=NormType.Spectral, act_cls=None, bias=False)\n",
    "\n",
    "    def forward(self, x):\n",
    "        #Notation from the paper.\n",
    "        size = x.size()\n",
    "        x = x.view(*size[:2],-1)\n",
    "        f,g,h = self.query(x),self.key(x),self.value(x)\n",
    "        beta = F.softmax(torch.bmm(f.transpose(1,2), g), dim=1)\n",
    "        o = self.gamma * torch.bmm(h, beta) + x\n",
    "        return o.view(*size).contiguous()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Self-attention layer as introduced in [Self-Attention Generative Adversarial Networks](https://arxiv.org/abs/1805.08318).\n",
    "\n",
    "Initially, no change is done to the input. This is controlled by a trainable parameter named `gamma` as we return `x + gamma * out`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst = SelfAttention(16)\n",
    "x = torch.randn(32, 16, 8, 8)\n",
    "test_eq(tst(x),x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then during training `gamma` will probably change since it's a trainable parameter. Let's see what's happening when it gets a nonzero value."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst.gamma.data.fill_(1.)\n",
    "y = tst(x)\n",
    "test_eq(y.shape, [32,16,8,8])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The attention mechanism requires three matrix multiplications (here represented by 1x1 convs). The multiplications are done on the channel level (the second dimension in our tensor) and we flatten the feature map (which is 8x8 here). As in the paper, we note `f`, `g` and `h` the results of those multiplications."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "q,k,v = tst.query[0].weight.data,tst.key[0].weight.data,tst.value[0].weight.data\n",
    "test_eq([q.shape, k.shape, v.shape], [[2, 16, 1], [2, 16, 1], [16, 16, 1]])\n",
    "f,g,h = map(lambda m: x.view(32, 16, 64).transpose(1,2) @ m.squeeze().t(), [q,k,v])\n",
    "test_eq([f.shape, g.shape, h.shape], [[32,64,2], [32,64,2], [32,64,16]])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The key part of the attention layer is to compute attention weights for each of our location in the feature map (here 8x8 = 64). Those are positive numbers that sum to 1 and tell the model to pay attention to this or that part of the picture. We make the product of `f` and the transpose of `g` (to get something of size bs by 64 by 64) then apply a softmax on the first dimension (to get the positive numbers that sum up to 1). The result can then be multiplied with `h` transposed to get an output of size bs by channels by 64, which we can then be viewed as an output the same size as the original input. \n",
    "\n",
    "The final result is then `x + gamma * out` as we saw before."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "beta = F.softmax(torch.bmm(f, g.transpose(1,2)), dim=1)\n",
    "test_eq(beta.shape, [32, 64, 64])\n",
    "out = torch.bmm(h.transpose(1,2), beta)\n",
    "test_eq(out.shape, [32, 16, 64])\n",
    "test_close(y, x + out.view(32, 16, 8, 8), eps=1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class PooledSelfAttention2d(Module):\n",
    "    \"Pooled self attention layer for 2d.\"\n",
    "    def __init__(self, n_channels):\n",
    "        self.n_channels = n_channels\n",
    "        self.query,self.key,self.value = [self._conv(n_channels, c) for c in (n_channels//8,n_channels//8,n_channels//2)]\n",
    "        self.out   = self._conv(n_channels//2, n_channels)\n",
    "        self.gamma = nn.Parameter(tensor([0.]))\n",
    "\n",
    "    def _conv(self,n_in,n_out):\n",
    "        return ConvLayer(n_in, n_out, ks=1, norm_type=NormType.Spectral, act_cls=None, bias=False)\n",
    "\n",
    "    def forward(self, x):\n",
    "        n_ftrs = x.shape[2]*x.shape[3]\n",
    "        f = self.query(x).view(-1, self.n_channels//8, n_ftrs)\n",
    "        g = F.max_pool2d(self.key(x),   [2,2]).view(-1, self.n_channels//8, n_ftrs//4)\n",
    "        h = F.max_pool2d(self.value(x), [2,2]).view(-1, self.n_channels//2, n_ftrs//4)\n",
    "        beta = F.softmax(torch.bmm(f.transpose(1, 2), g), -1)\n",
    "        o = self.out(torch.bmm(h, beta.transpose(1,2)).view(-1, self.n_channels//2, x.shape[2], x.shape[3]))\n",
    "        return self.gamma * o + x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Self-attention layer used in the [Big GAN paper](https://arxiv.org/abs/1809.11096).\n",
    "\n",
    "It uses the same attention as in `SelfAttention` but adds a max pooling of stride 2 before computing the matrices `g` and `h`: the attention is ported on one of the 2x2 max-pooled window, not the whole feature map. There is also a final matrix product added at the end to the output, before retuning `gamma * out + x`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def _conv1d_spect(ni:int, no:int, ks:int=1, stride:int=1, padding:int=0, bias:bool=False):\n",
    "    \"Create and initialize a `nn.Conv1d` layer with spectral normalization.\"\n",
    "    conv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias)\n",
    "    nn.init.kaiming_normal_(conv.weight)\n",
    "    if bias: conv.bias.data.zero_()\n",
    "    return spectral_norm(conv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class SimpleSelfAttention(Module):\n",
    "    def __init__(self, n_in:int, ks=1, sym=False):\n",
    "        self.sym,self.n_in = sym,n_in\n",
    "        self.conv = _conv1d_spect(n_in, n_in, ks, padding=ks//2, bias=False)\n",
    "        self.gamma = nn.Parameter(tensor([0.]))\n",
    "\n",
    "    def forward(self,x):\n",
    "        if self.sym:\n",
    "            c = self.conv.weight.view(self.n_in,self.n_in)\n",
    "            c = (c + c.t())/2\n",
    "            self.conv.weight = c.view(self.n_in,self.n_in,1)\n",
    "\n",
    "        size = x.size()\n",
    "        x = x.view(*size[:2],-1)\n",
    "\n",
    "        convx = self.conv(x)\n",
    "        xxT = torch.bmm(x,x.permute(0,2,1).contiguous())\n",
    "        o = torch.bmm(xxT, convx)\n",
    "        o = self.gamma * o + x\n",
    "        return o.view(*size).contiguous()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PixelShuffle"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "PixelShuffle introduced in [this article](https://arxiv.org/pdf/1609.05158.pdf) to avoid checkerboard artifacts when upsampling images. If we want an output with `ch_out` filters, we use a convolution with `ch_out * (r**2)` filters, where `r` is the upsampling factor. Then we reorganize those filters like in the picture below:\n",
    "\n",
    "<img src=\"images/pixelshuffle.png\" alt=\"Pixelshuffle\" width=\"800\" />"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def icnr_init(x, scale=2, init=nn.init.kaiming_normal_):\n",
    "    \"ICNR init of `x`, with `scale` and `init` function\"\n",
    "    ni,nf,h,w = x.shape\n",
    "    ni2 = int(ni/(scale**2))\n",
    "    k = init(x.new_zeros([ni2,nf,h,w])).transpose(0, 1)\n",
    "    k = k.contiguous().view(ni2, nf, -1)\n",
    "    k = k.repeat(1, 1, scale**2)\n",
    "    return k.contiguous().view([nf,ni,h,w]).transpose(0, 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "ICNR init was introduced in [this article](https://arxiv.org/abs/1707.02937). It suggests to initialize the convolution that will be used in PixelShuffle so that each of the `r**2` channels get the same weight (so that in the picture above, the 9 colors in a 3 by 3 window are initially the same).\n",
    "\n",
    "> Note: This is done on the first dimension because PyTorch stores the weights of a convolutional layer in this format: `ch_out x ch_in x ks x ks`. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst = torch.randn(16*4, 32, 1, 1)\n",
    "tst = icnr_init(tst)\n",
    "for i in range(0,16*4,4):\n",
    "    test_eq(tst[i],tst[i+1])\n",
    "    test_eq(tst[i],tst[i+2])\n",
    "    test_eq(tst[i],tst[i+3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class PixelShuffle_ICNR(nn.Sequential):\n",
    "    \"Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`.\"\n",
    "    def __init__(self, ni, nf=None, scale=2, blur=False, norm_type=NormType.Weight, act_cls=defaults.activation):\n",
    "        super().__init__()\n",
    "        nf = ifnone(nf, ni)\n",
    "        layers = [ConvLayer(ni, nf*(scale**2), ks=1, norm_type=norm_type, act_cls=act_cls, bias_std=0),\n",
    "                  nn.PixelShuffle(scale)]\n",
    "        if norm_type == NormType.Weight:\n",
    "            layers[0][0].weight_v.data.copy_(icnr_init(layers[0][0].weight_v.data))\n",
    "            layers[0][0].weight_g.data.copy_(((layers[0][0].weight_v.data**2).sum(dim=[1,2,3])**0.5)[:,None,None,None])\n",
    "        else:\n",
    "            layers[0][0].weight.data.copy_(icnr_init(layers[0][0].weight.data))\n",
    "        if blur: layers += [nn.ReplicationPad2d((1,0,1,0)), nn.AvgPool2d(2, stride=1)]\n",
    "        super().__init__(*layers)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The convolutional layer is initialized with `icnr_init` and passed `act_cls` and `norm_type` (the default of weight normalization seemed to be what's best for super-resolution problems, in our experiments). \n",
    "\n",
    "The `blur` option comes from [Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts](https://arxiv.org/abs/1806.02658) where the authors add a little bit of blur to completely get rid of checkerboard artifacts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "psfl = PixelShuffle_ICNR(16)\n",
    "x = torch.randn(64, 16, 8, 8)\n",
    "y = psfl(x)\n",
    "test_eq(y.shape, [64, 16, 16, 16])\n",
    "#ICNR init makes every 2x2 window (stride 2) have the same elements\n",
    "for i in range(0,16,2):\n",
    "    for j in range(0,16,2):\n",
    "        test_eq(y[:,:,i,j],y[:,:,i+1,j])\n",
    "        test_eq(y[:,:,i,j],y[:,:,i  ,j+1])\n",
    "        test_eq(y[:,:,i,j],y[:,:,i+1,j+1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "psfl = PixelShuffle_ICNR(16, norm_type=None)\n",
    "x = torch.randn(64, 16, 8, 8)\n",
    "y = psfl(x)\n",
    "test_eq(y.shape, [64, 16, 16, 16])\n",
    "#ICNR init makes every 2x2 window (stride 2) have the same elements\n",
    "for i in range(0,16,2):\n",
    "    for j in range(0,16,2):\n",
    "        test_eq(y[:,:,i,j],y[:,:,i+1,j])\n",
    "        test_eq(y[:,:,i,j],y[:,:,i  ,j+1])\n",
    "        test_eq(y[:,:,i,j],y[:,:,i+1,j+1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "psfl = PixelShuffle_ICNR(16, norm_type=NormType.Spectral)\n",
    "x = torch.randn(64, 16, 8, 8)\n",
    "y = psfl(x)\n",
    "test_eq(y.shape, [64, 16, 16, 16])\n",
    "#ICNR init makes every 2x2 window (stride 2) have the same elements\n",
    "for i in range(0,16,2):\n",
    "    for j in range(0,16,2):\n",
    "        test_eq(y[:,:,i,j],y[:,:,i+1,j])\n",
    "        test_eq(y[:,:,i,j],y[:,:,i  ,j+1])\n",
    "        test_eq(y[:,:,i,j],y[:,:,i+1,j+1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sequential extensions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def sequential(*args):\n",
    "    \"Create an `nn.Sequential`, wrapping items with `Lambda` if needed\"\n",
    "    if len(args) != 1 or not isinstance(args[0], OrderedDict):\n",
    "        args = list(args)\n",
    "        for i,o in enumerate(args):\n",
    "            if not isinstance(o,nn.Module): args[i] = Lambda(o)\n",
    "    return nn.Sequential(*args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class SequentialEx(Module):\n",
    "    \"Like `nn.Sequential`, but with ModuleList semantics, and can access module input\"\n",
    "    def __init__(self, *layers): self.layers = nn.ModuleList(layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        res = x\n",
    "        for l in self.layers:\n",
    "            res.orig = x\n",
    "            nres = l(res)\n",
    "            # We have to remove res.orig to avoid hanging refs and therefore memory leaks\n",
    "            res.orig, nres.orig = None, None\n",
    "            res = nres\n",
    "        return res\n",
    "\n",
    "    def __getitem__(self,i): return self.layers[i]\n",
    "    def append(self,l):      return self.layers.append(l)\n",
    "    def extend(self,l):      return self.layers.extend(l)\n",
    "    def insert(self,i,l):    return self.layers.insert(i,l)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is useful to write layers that require to remember the input (like a resnet block) in a sequential way."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class MergeLayer(Module):\n",
    "    \"Merge a shortcut with the result of the module by adding them or concatenating them if `dense=True`.\"\n",
    "    def __init__(self, dense:bool=False): self.dense=dense\n",
    "    def forward(self, x): return torch.cat([x,x.orig], dim=1) if self.dense else (x+x.orig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res_block = SequentialEx(ConvLayer(16, 16), ConvLayer(16,16))\n",
    "res_block.append(MergeLayer()) # just to test append - normally it would be in init params\n",
    "x = torch.randn(32, 16, 8, 8)\n",
    "y = res_block(x)\n",
    "test_eq(y.shape, [32, 16, 8, 8])\n",
    "test_eq(y, x + res_block[1](res_block[0](x)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = TensorBase(torch.randn(32, 16, 8, 8))\n",
    "y = res_block(x)\n",
    "test_is(y.orig, None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Concat"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Equivalent to keras.layers.Concatenate, it will concat the outputs of a ModuleList over a given dimension (default the filter dimension)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export \n",
    "class Cat(nn.ModuleList):\n",
    "    \"Concatenate layers outputs over a given dim\"\n",
    "    def __init__(self, layers, dim=1):\n",
    "        self.dim=dim\n",
    "        super().__init__(layers)\n",
    "    def forward(self, x): return torch.cat([l(x) for l in self], dim=self.dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "layers = [ConvLayer(2,4), ConvLayer(2,4), ConvLayer(2,4)] \n",
    "x = torch.rand(1,2,8,8) \n",
    "cat = Cat(layers) \n",
    "test_eq(cat(x).shape, [1,12,8,8]) \n",
    "test_eq(cat(x), torch.cat([l(x) for l in layers], dim=1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Ready-to-go models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class SimpleCNN(nn.Sequential):\n",
    "    \"Create a simple CNN with `filters`.\"\n",
    "    def __init__(self, filters, kernel_szs=None, strides=None, bn=True):\n",
    "        nl = len(filters)-1\n",
    "        kernel_szs = ifnone(kernel_szs, [3]*nl)\n",
    "        strides    = ifnone(strides   , [2]*nl)\n",
    "        layers = [ConvLayer(filters[i], filters[i+1], kernel_szs[i], stride=strides[i],\n",
    "                  norm_type=(NormType.Batch if bn and i<nl-1 else None)) for i in range(nl)]\n",
    "        layers.append(PoolFlatten())\n",
    "        super().__init__(*layers)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model is a succession of convolutional layers from `(filters[0],filters[1])` to `(filters[n-2],filters[n-1])` (if `n` is the length of the `filters` list) followed by a `PoolFlatten`. `kernel_szs` and `strides` defaults to a list of 3s and a list of 2s. If `bn=True` the convolutional layers are successions of conv-relu-batchnorm, otherwise conv-relu."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst = SimpleCNN([8,16,32])\n",
    "mods = list(tst.children())\n",
    "test_eq(len(mods), 3)\n",
    "test_eq([[m[0].in_channels, m[0].out_channels] for m in mods[:2]], [[8,16], [16,32]])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test kernel sizes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst = SimpleCNN([8,16,32], kernel_szs=[1,3])\n",
    "mods = list(tst.children())\n",
    "test_eq([m[0].kernel_size for m in mods[:2]], [(1,1), (3,3)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test strides"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst = SimpleCNN([8,16,32], strides=[1,2])\n",
    "mods = list(tst.children())\n",
    "test_eq([m[0].stride for m in mods[:2]], [(1,1),(2,2)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class ProdLayer(Module):\n",
    "    \"Merge a shortcut with the result of the module by multiplying them.\"\n",
    "    def forward(self, x): return x * x.orig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "inplace_relu = partial(nn.ReLU, inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def SEModule(ch, reduction, act_cls=defaults.activation):\n",
    "    nf = math.ceil(ch//reduction/8)*8\n",
    "    return SequentialEx(nn.AdaptiveAvgPool2d(1),\n",
    "                        ConvLayer(ch, nf, ks=1, norm_type=None, act_cls=act_cls),\n",
    "                        ConvLayer(nf, ch, ks=1, norm_type=None, act_cls=nn.Sigmoid),\n",
    "                        ProdLayer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class ResBlock(Module):\n",
    "    \"Resnet block from `ni` to `nh` with `stride`\"\n",
    "    @delegates(ConvLayer.__init__)\n",
    "    def __init__(self, expansion, ni, nf, stride=1, groups=1, reduction=None, nh1=None, nh2=None, dw=False, g2=1,\n",
    "                 sa=False, sym=False, norm_type=NormType.Batch, act_cls=defaults.activation, ndim=2, ks=3,\n",
    "                 pool=AvgPool, pool_first=True, **kwargs):\n",
    "        norm2 = (NormType.BatchZero if norm_type==NormType.Batch else\n",
    "                 NormType.InstanceZero if norm_type==NormType.Instance else norm_type)\n",
    "        if nh2 is None: nh2 = nf\n",
    "        if nh1 is None: nh1 = nh2\n",
    "        nf,ni = nf*expansion,ni*expansion\n",
    "        k0 = dict(norm_type=norm_type, act_cls=act_cls, ndim=ndim, **kwargs)\n",
    "        k1 = dict(norm_type=norm2, act_cls=None, ndim=ndim, **kwargs)\n",
    "        convpath  = [ConvLayer(ni,  nh2, ks, stride=stride, groups=ni if dw else groups, **k0),\n",
    "                     ConvLayer(nh2,  nf, ks, groups=g2, **k1)\n",
    "        ] if expansion == 1 else [\n",
    "                     ConvLayer(ni,  nh1, 1, **k0),\n",
    "                     ConvLayer(nh1, nh2, ks, stride=stride, groups=nh1 if dw else groups, **k0),\n",
    "                     ConvLayer(nh2,  nf, 1, groups=g2, **k1)]\n",
    "        if reduction: convpath.append(SEModule(nf, reduction=reduction, act_cls=act_cls))\n",
    "        if sa: convpath.append(SimpleSelfAttention(nf,ks=1,sym=sym))\n",
    "        self.convpath = nn.Sequential(*convpath)\n",
    "        idpath = []\n",
    "        if ni!=nf: idpath.append(ConvLayer(ni, nf, 1, act_cls=None, ndim=ndim, **kwargs))\n",
    "        if stride!=1: idpath.insert((1,0)[pool_first], pool(stride, ndim=ndim, ceil_mode=True))\n",
    "        self.idpath = nn.Sequential(*idpath)\n",
    "        self.act = defaults.activation(inplace=True) if act_cls is defaults.activation else act_cls()\n",
    "\n",
    "    def forward(self, x): return self.act(self.convpath(x) + self.idpath(x))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is a resnet block (normal or bottleneck depending on `expansion`, 1 for the normal block and 4 for the traditional bottleneck) that implements the tweaks from [Bag of Tricks for Image Classification with Convolutional Neural Networks](https://arxiv.org/abs/1812.01187). In particular, the last batchnorm layer (if that is the selected `norm_type`) is initialized with a weight (or gamma) of zero to facilitate the flow from the beginning to the end of the network. It also implements optional [Squeeze and Excitation](https://arxiv.org/abs/1709.01507) and grouped convs for [ResNeXT](https://arxiv.org/abs/1611.05431) and similar models (use `dw=True` for depthwise convs).\n",
    "\n",
    "The `kwargs` are passed to `ConvLayer` along with `norm_type`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def SEBlock(expansion, ni, nf, groups=1, reduction=16, stride=1, **kwargs):\n",
    "    return ResBlock(expansion, ni, nf, stride=stride, groups=groups, reduction=reduction, nh1=nf*2, nh2=nf*expansion, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def SEResNeXtBlock(expansion, ni, nf, groups=32, reduction=16, stride=1, base_width=4, **kwargs):\n",
    "    w = math.floor(nf * (base_width / 64)) * groups\n",
    "    return ResBlock(expansion, ni, nf, stride=stride, groups=groups, reduction=reduction, nh2=w, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def SeparableBlock(expansion, ni, nf, reduction=16, stride=1, base_width=4, **kwargs):\n",
    "    return ResBlock(expansion, ni, nf, stride=stride, reduction=reduction, nh2=nf*2, dw=True, **kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Time Distributed Layer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Equivalent to Keras `TimeDistributed` Layer, enables computing pytorch `Module` over an axis."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def _stack_tups(tuples, stack_dim=1):\n",
    "    \"Stack tuple of tensors along `stack_dim`\"\n",
    "    return tuple(torch.stack([t[i] for t in tuples], dim=stack_dim) for i in range_of(tuples[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class TimeDistributed(Module):\n",
    "    \"Applies `module` over `tdim` identically for each step, use `low_mem` to compute one at a time.\" \n",
    "    def __init__(self, module, low_mem=False, tdim=1):\n",
    "        store_attr()\n",
    "        \n",
    "    def forward(self, *tensors, **kwargs):\n",
    "        \"input x with shape:(bs,seq_len,channels,width,height)\"\n",
    "        if self.low_mem or self.tdim!=1: \n",
    "            return self.low_mem_forward(*tensors, **kwargs)\n",
    "        else:\n",
    "            #only support tdim=1\n",
    "            inp_shape = tensors[0].shape\n",
    "            bs, seq_len = inp_shape[0], inp_shape[1]   \n",
    "            out = self.module(*[x.view(bs*seq_len, *x.shape[2:]) for x in tensors], **kwargs)\n",
    "        return self.format_output(out, bs, seq_len)\n",
    "    \n",
    "    def low_mem_forward(self, *tensors, **kwargs):                                           \n",
    "        \"input x with shape:(bs,seq_len,channels,width,height)\"\n",
    "        seq_len = tensors[0].shape[self.tdim]\n",
    "        args_split = [torch.unbind(x, dim=self.tdim) for x in tensors]\n",
    "        out = []\n",
    "        for i in range(seq_len):\n",
    "            out.append(self.module(*[args[i] for args in args_split]), **kwargs)\n",
    "        if isinstance(out[0], tuple):\n",
    "            return _stack_tups(out, stack_dim=self.tdim)\n",
    "        return torch.stack(out, dim=self.tdim)\n",
    "    \n",
    "    def format_output(self, out, bs, seq_len):\n",
    "        \"unstack from batchsize outputs\"\n",
    "        if isinstance(out, tuple):\n",
    "            return tuple(out_i.view(bs, seq_len, *out_i.shape[1:]) for out_i in out)\n",
    "        return out.view(bs, seq_len,*out.shape[1:])\n",
    "    \n",
    "    def __repr__(self):\n",
    "        return f'TimeDistributed({self.module})'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs, seq_len = 2, 5\n",
    "x, y = torch.rand(bs,seq_len,3,2,2), torch.rand(bs,seq_len,3,2,2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tconv = TimeDistributed(nn.Conv2d(3,4,1))\n",
    "test_eq(tconv(x).shape, (2,5,4,2,2))\n",
    "tconv.low_mem=True\n",
    "test_eq(tconv(x).shape, (2,5,4,2,2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Mod(Module):\n",
    "    def __init__(self):\n",
    "        self.conv = nn.Conv2d(3,4,1)\n",
    "    def forward(self, x, y):\n",
    "        return self.conv(x) + self.conv(y)\n",
    "tmod = TimeDistributed(Mod())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "out = tmod(x,y)\n",
    "test_eq(out.shape, (2,5,4,2,2))\n",
    "tmod.low_mem=True\n",
    "out_low_mem = tmod(x,y)\n",
    "test_eq(out_low_mem.shape, (2,5,4,2,2))\n",
    "test_eq(out, out_low_mem)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Mod2(Module):\n",
    "    def __init__(self):\n",
    "        self.conv = nn.Conv2d(3,4,1)\n",
    "    def forward(self, x, y):\n",
    "        return self.conv(x), self.conv(y)\n",
    "tmod2 = TimeDistributed(Mod2())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "out = tmod2(x,y)\n",
    "test_eq(len(out), 2)\n",
    "test_eq(out[0].shape, (2,5,4,2,2))\n",
    "tmod2.low_mem=True\n",
    "out_low_mem = tmod2(x,y)\n",
    "test_eq(out_low_mem[0].shape, (2,5,4,2,2))\n",
    "test_eq(out, out_low_mem)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<h2 id=\"TimeDistributed\" class=\"doc_header\"><code>class</code> <code>TimeDistributed</code><a href=\"\" class=\"source_link\" style=\"float:right\">[source]</a></h2>\n",
       "\n",
       "> <code>TimeDistributed</code>(**`module`**, **`low_mem`**=*`False`*, **`tdim`**=*`1`*) :: [`Module`](/torch_core.html#Module)\n",
       "\n",
       "Applies `module` over `tdim` identically for each step, use `low_mem` to compute one at a time."
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "show_doc(TimeDistributed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This module is equivalent to [Keras TimeDistributed Layer](https://keras.io/api/layers/recurrent_layers/time_distributed/). This wrapper allows to apply a layer to every temporal slice of an input. By default it is assumed the time axis (`tdim`) is the 1st one (the one after the batch size). A typical usage would be to encode a sequence of images using an image encoder.\n",
    "\n",
    "The `forward` function of `TimeDistributed` supports `*args` and `**kkwargs` but only `args` will be split and passed to the underlying module independently for each timestep, `kwargs` will be passed as they are. This is useful when you have module that take multiple arguments as inputs, this way, you can put all tensors you need spliting as `args` and other arguments that don't need split as `kwargs`.\n",
    "\n",
 This module is heavy on memory">
    "> This module is heavy on memory, as it will try to pass multiple timesteps at the same time on the batch dimension. If you get out of memory errors, first try reducing your batch size by the number of timesteps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai.vision.all import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "encoder = create_body(resnet18)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A resnet18 will encode a feature map of 512 channels. Height and Width will be divided by 32."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "time_resnet = TimeDistributed(encoder)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "a synthetic batch of 2 image-sequences of lenght 5. `(bs, seq_len, ch, w, h)`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image_sequence = torch.rand(2, 5, 3, 64, 64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/jhoward/miniconda3/lib/python3.8/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at  /opt/conda/conda-bld/pytorch_1623448278899/work/c10/core/TensorImpl.h:1156.)\n",
      "  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 5, 512, 2, 2])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "time_resnet(image_sequence).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This way, one can encode a sequence of images on feature space.\n",
    "There is also a `low_mem_forward` that will pass images one at a time to reduce GPU memory consumption."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 5, 512, 2, 2])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "time_resnet.low_mem_forward(image_sequence).shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Swish and Mish"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from torch.jit import script"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@script\n",
    "def _swish_jit_fwd(x): return x.mul(torch.sigmoid(x))\n",
    "\n",
    "@script\n",
    "def _swish_jit_bwd(x, grad_output):\n",
    "    x_sigmoid = torch.sigmoid(x)\n",
    "    return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))\n",
    "\n",
    "class _SwishJitAutoFn(torch.autograd.Function):\n",
    "    @staticmethod\n",
    "    def forward(ctx, x):\n",
    "        ctx.save_for_backward(x)\n",
    "        return _swish_jit_fwd(x)\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, grad_output):\n",
    "        x = ctx.saved_variables[0]\n",
    "        return _swish_jit_bwd(x, grad_output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def swish(x, inplace=False): return _SwishJitAutoFn.apply(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class Swish(Module):\n",
    "    def forward(self, x): return _SwishJitAutoFn.apply(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "@script\n",
    "def _mish_jit_fwd(x): return x.mul(torch.tanh(F.softplus(x)))\n",
    "\n",
    "@script\n",
    "def _mish_jit_bwd(x, grad_output):\n",
    "    x_sigmoid = torch.sigmoid(x)\n",
    "    x_tanh_sp = F.softplus(x).tanh()\n",
    "    return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))\n",
    "\n",
    "class MishJitAutoFn(torch.autograd.Function):\n",
    "    @staticmethod\n",
    "    def forward(ctx, x):\n",
    "        ctx.save_for_backward(x)\n",
    "        return _mish_jit_fwd(x)\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, grad_output):\n",
    "        x = ctx.saved_variables[0]\n",
    "        return _mish_jit_bwd(x, grad_output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def mish(x): return F.mish(x) if torch.__version__ >= '1.9' else MishJitAutoFn.apply(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class Mish(Module):\n",
    "    def forward(self, x): return MishJitAutoFn.apply(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "if torch.__version__ >= '1.9': Mish = nn.Mish"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "for o in swish,Swish,mish,Mish: o.__default_init__ = kaiming_uniform_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Helper functions for submodules"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It's easy to get the list of all parameters of a given model. For when you want all submodules (like linear/conv layers) without forgetting lone parameters, the following class wraps those in fake modules."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class ParameterModule(Module):\n",
    "    \"Register a lone parameter `p` in a module.\"\n",
    "    def __init__(self, p): self.val = p\n",
    "    def forward(self, x): return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def children_and_parameters(m):\n",
    "    \"Return the children of `m` and its direct parameters not registered in modules.\"\n",
    "    children = list(m.children())\n",
    "    children_p = sum([[id(p) for p in c.parameters()] for c in m.children()],[])\n",
    "    for p in m.parameters():\n",
    "        if id(p) not in children_p: children.append(ParameterModule(p))\n",
    "    return children"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TstModule(Module):\n",
    "    def __init__(self): self.a,self.lin = nn.Parameter(torch.randn(1)),nn.Linear(5,10)\n",
    "\n",
    "tst = TstModule()\n",
    "children = children_and_parameters(tst)\n",
    "test_eq(len(children), 2)\n",
    "test_eq(children[0], tst.lin)\n",
    "assert isinstance(children[1], ParameterModule)\n",
    "test_eq(children[1].val, tst.a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def has_children(m):\n",
    "    try: next(m.children())\n",
    "    except StopIteration: return False\n",
    "    return True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class A(Module): pass\n",
    "assert not has_children(A())\n",
    "assert has_children(TstModule())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def flatten_model(m):\n",
    "    \"Return the list of all submodules and parameters of `m`\"\n",
    "    return sum(map(flatten_model,children_and_parameters(m)),[]) if has_children(m) else [m]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst = nn.Sequential(TstModule(), TstModule())\n",
    "children = flatten_model(tst)\n",
    "test_eq(len(children), 4)\n",
    "assert isinstance(children[1], ParameterModule)\n",
    "assert isinstance(children[3], ParameterModule)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class NoneReduce():\n",
    "    \"A context manager to evaluate `loss_func` with none reduce.\"\n",
    "    def __init__(self, loss_func): self.loss_func,self.old_red = loss_func,None\n",
    "\n",
    "    def __enter__(self):\n",
    "        if hasattr(self.loss_func, 'reduction'):\n",
    "            self.old_red = self.loss_func.reduction\n",
    "            self.loss_func.reduction = 'none'\n",
    "            return self.loss_func\n",
    "        else: return partial(self.loss_func, reduction='none')\n",
    "\n",
    "    def __exit__(self, type, value, traceback):\n",
    "        if self.old_red is not None: self.loss_func.reduction = self.old_red"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x,y = torch.randn(5),torch.randn(5)\n",
    "loss_fn = nn.MSELoss()\n",
    "with NoneReduce(loss_fn) as loss_func:\n",
    "    loss = loss_func(x,y)\n",
    "test_eq(loss.shape, [5])\n",
    "test_eq(loss_fn.reduction, 'mean')\n",
    "\n",
    "loss_fn = F.mse_loss\n",
    "with NoneReduce(loss_fn) as loss_func:\n",
    "    loss = loss_func(x,y)\n",
    "test_eq(loss.shape, [5])\n",
    "test_eq(loss_fn, F.mse_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def in_channels(m):\n",
    "    \"Return the shape of the first weight layer in `m`.\"\n",
    "    for l in flatten_model(m):\n",
    "        if getattr(l, 'weight', None) is not None and l.weight.ndim==4:\n",
    "            return l.weight.shape[1]\n",
    "    raise Exception('No weight layer')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_eq(in_channels(nn.Sequential(nn.Conv2d(5,4,3), nn.Conv2d(4,3,3))), 5)\n",
    "test_eq(in_channels(nn.Sequential(nn.AvgPool2d(4), nn.Conv2d(4,3,3))), 4)\n",
    "test_eq(in_channels(nn.Sequential(BatchNorm(4), nn.Conv2d(4,3,3))), 4)\n",
    "test_eq(in_channels(nn.Sequential(InstanceNorm(4), nn.Conv2d(4,3,3))), 4)\n",
    "test_eq(in_channels(nn.Sequential(InstanceNorm(4, affine=False), nn.Conv2d(4,3,3))), 4)\n",
    "test_fail(lambda : in_channels(nn.Sequential(nn.AvgPool2d(4))))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 00_torch_core.ipynb.\n",
      "Converted 01_layers.ipynb.\n",
      "Converted 01a_losses.ipynb.\n",
      "Converted 02_data.load.ipynb.\n",
      "Converted 03_data.core.ipynb.\n",
      "Converted 04_data.external.ipynb.\n",
      "Converted 05_data.transforms.ipynb.\n",
      "Converted 06_data.block.ipynb.\n",
      "Converted 07_vision.core.ipynb.\n",
      "Converted 08_vision.data.ipynb.\n",
      "Converted 09_vision.augment.ipynb.\n",
      "Converted 09b_vision.utils.ipynb.\n",
      "Converted 09c_vision.widgets.ipynb.\n",
      "Converted 10_tutorial.pets.ipynb.\n",
      "Converted 10b_tutorial.albumentations.ipynb.\n",
      "Converted 11_vision.models.xresnet.ipynb.\n",
      "Converted 12_optimizer.ipynb.\n",
      "Converted 13_callback.core.ipynb.\n",
      "Converted 13a_learner.ipynb.\n",
      "Converted 13b_metrics.ipynb.\n",
      "Converted 14_callback.schedule.ipynb.\n",
      "Converted 14a_callback.data.ipynb.\n",
      "Converted 15_callback.hook.ipynb.\n",
      "Converted 15a_vision.models.unet.ipynb.\n",
      "Converted 16_callback.progress.ipynb.\n",
      "Converted 17_callback.tracker.ipynb.\n",
      "Converted 18_callback.fp16.ipynb.\n",
      "Converted 18a_callback.training.ipynb.\n",
      "Converted 18b_callback.preds.ipynb.\n",
      "Converted 19_callback.mixup.ipynb.\n",
      "Converted 20_interpret.ipynb.\n",
      "Converted 20a_distributed.ipynb.\n",
      "Converted 21_vision.learner.ipynb.\n",
      "Converted 22_tutorial.imagenette.ipynb.\n",
      "Converted 23_tutorial.vision.ipynb.\n",
      "Converted 24_tutorial.image_sequence.ipynb.\n",
      "Converted 24_tutorial.siamese.ipynb.\n",
      "Converted 24_vision.gan.ipynb.\n",
      "Converted 30_text.core.ipynb.\n",
      "Converted 31_text.data.ipynb.\n",
      "Converted 32_text.models.awdlstm.ipynb.\n",
      "Converted 33_text.models.core.ipynb.\n",
      "Converted 34_callback.rnn.ipynb.\n",
      "Converted 35_tutorial.wikitext.ipynb.\n",
      "Converted 37_text.learner.ipynb.\n",
      "Converted 38_tutorial.text.ipynb.\n",
      "Converted 39_tutorial.transformers.ipynb.\n",
      "Converted 40_tabular.core.ipynb.\n",
      "Converted 41_tabular.data.ipynb.\n",
      "Converted 42_tabular.model.ipynb.\n",
      "Converted 43_tabular.learner.ipynb.\n",
      "Converted 44_tutorial.tabular.ipynb.\n",
      "Converted 45_collab.ipynb.\n",
      "Converted 46_tutorial.collab.ipynb.\n",
      "Converted 50_tutorial.datablock.ipynb.\n",
      "Converted 60_medical.imaging.ipynb.\n",
      "Converted 61_tutorial.medical_imaging.ipynb.\n",
      "Converted 65_medical.text.ipynb.\n",
      "Converted 70_callback.wandb.ipynb.\n",
      "Converted 71_callback.tensorboard.ipynb.\n",
      "Converted 72_callback.neptune.ipynb.\n",
      "Converted 73_callback.captum.ipynb.\n",
      "Converted 74_callback.azureml.ipynb.\n",
      "Converted 97_test_utils.ipynb.\n",
      "Converted 99_pytorch_doc.ipynb.\n",
      "Converted dev-setup.ipynb.\n",
      "Converted app_examples.ipynb.\n",
      "Converted camvid.ipynb.\n",
      "Converted migrating_catalyst.ipynb.\n",
      "Converted migrating_ignite.ipynb.\n",
      "Converted migrating_lightning.ipynb.\n",
      "Converted migrating_pytorch.ipynb.\n",
      "Converted ulmfit.ipynb.\n",
      "Converted index.ipynb.\n",
      "Converted quick_start.ipynb.\n",
      "Converted tutorial.ipynb.\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "from nbdev.export import *\n",
    "notebook2script()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
